{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# C - Loading, Saving and Freezing Embeddings\n",
    "\n",
    "This notebook will cover: how to load custom word embeddings in TorchText, how to save all the embeddings we learn during training and how to freeze/unfreeze embeddings during training. \n",
    "\n",
    "## Loading Custom Embeddings\n",
    "\n",
    "First, lets look at loading a custom set of embeddings.\n",
    "\n",
    "Your embeddings need to be formatted so each line starts with the word followed by the values of the embedding vector, all space separated. All vectors need to have the same number of elements.\n",
    "\n",
    "Let's look at the custom embeddings provided by these tutorials. These are 20-dimensional embeddings for 7 words."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "good 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0\n",
      "great 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0\n",
      "awesome 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0\n",
      "bad -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0\n",
      "terrible -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0\n",
      "awful -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0\n",
      "kwyjibo 0.5 -0.5 0.5 -0.5 0.5 -0.5 0.5 -0.5 0.5 -0.5 0.5 -0.5 0.5 -0.5 0.5 -0.5 0.5 -0.5 0.5 -0.5\n",
      "\n"
     ]
    }
   ],
   "source": [
    "with open('custom_embeddings/embeddings.txt', 'r') as f:\n",
    "    print(f.read())"
   ]
  },
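  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the format shown above, the sketch below parses embedding lines by hand and checks that every vector has the same length. The `parse_embeddings` helper and the three-dimensional sample lines are hypothetical and only stand in for a real file; TorchText does its own parsing when it builds a `Vectors` object.\n",
    "\n",
    "```python\n",
    "def parse_embeddings(lines):\n",
    "    # Map each word to its vector, checking that all vectors share one length.\n",
    "    table = {}\n",
    "    dim = None\n",
    "    for line in lines:\n",
    "        parts = line.rstrip().split(' ')\n",
    "        if len(parts) < 2:\n",
    "            continue  # skip blank or malformed lines\n",
    "        word, values = parts[0], [float(v) for v in parts[1:]]\n",
    "        if dim is None:\n",
    "            dim = len(values)\n",
    "        elif len(values) != dim:\n",
    "            raise ValueError(f'{word} has {len(values)} values, expected {dim}')\n",
    "        table[word] = values\n",
    "    return table, dim\n",
    "\n",
    "# Hypothetical 3-dimensional sample in the same layout as embeddings.txt.\n",
    "sample = ['good 1.0 1.0 1.0', 'bad -1.0 -1.0 -1.0', 'kwyjibo 0.5 -0.5 0.5']\n",
    "table, dim = parse_embeddings(sample)\n",
    "print(dim, sorted(table))\n",
    "```\n",
    "\n",
    "A line with a different number of values raises a `ValueError` here, which mirrors the requirement that all vectors have the same number of elements."
   ]
  },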
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's setup the fields."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torchtext.legacy import data\n",
    "\n",
    "SEED = 1234\n",
    "\n",
    "torch.manual_seed(SEED)\n",
    "torch.backends.cudnn.deterministic = True\n",
    "\n",
    "TEXT = data.Field(tokenize = 'spacy')\n",
    "LABEL = data.LabelField(dtype = torch.float)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, we'll load our dataset and create the validation set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchtext.legacy import datasets\n",
    "import random\n",
    "\n",
    "train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)\n",
    "\n",
    "train_data, valid_data = train_data.split(random_state = random.seed(SEED))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can only load our custom embeddings after they have been turned into a `Vectors` object.\n",
    "\n",
    "We create a `Vector` object by passing it the location of the embeddings (`name`), a location for the cached embeddings (`cache`) and a function that will later initialize tokens in our embeddings that aren't within our dataset (`unk_init`). As have done in previous notebooks, we have initialized these to $\\mathcal{N}(0,1)$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/7 [00:00<?, ?it/s]\n"
     ]
    }
   ],
   "source": [
    "import torchtext.vocab as vocab\n",
    "\n",
    "custom_embeddings = vocab.Vectors(name = 'custom_embeddings/embeddings.txt',\n",
    "                                  cache = 'custom_embeddings',\n",
    "                                  unk_init = torch.Tensor.normal_)"
   ]
  },
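  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the role of `unk_init` more concrete, here is a small sketch of the kind of call involved. It assumes that `unk_init` is applied to a tensor of the embedding dimension and returns a tensor of the same size; the zero-filled starting tensor and the `small_normal` alternative are our own illustrations, not part of TorchText.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(1234)\n",
    "\n",
    "dim = 20\n",
    "\n",
    "# torch.Tensor.normal_ fills a tensor in-place with samples from N(0, 1)\n",
    "# and returns that same tensor.\n",
    "vec = torch.Tensor.normal_(torch.zeros(dim))\n",
    "print(vec.shape)\n",
    "\n",
    "# A hypothetical alternative: a narrower distribution, N(0, 0.1^2).\n",
    "def small_normal(tensor):\n",
    "    return tensor.normal_(mean=0.0, std=0.1)\n",
    "\n",
    "print(small_normal(torch.zeros(dim)).shape)\n",
    "```\n",
    "\n",
    "Under that assumption, any callable with this shape-preserving behaviour could be passed as `unk_init`."
   ]
  },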
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To check the embeddings have loaded correctly we can print out the words loaded from our custom embedding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'good': 0, 'great': 1, 'awesome': 2, 'bad': 3, 'terrible': 4, 'awful': 5, 'kwyjibo': 6}\n"
     ]
    }
   ],
   "source": [
    "print(custom_embeddings.stoi)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also directly print out the embedding values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,\n",
      "          1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,\n",
      "          1.0000,  1.0000,  1.0000,  1.0000],\n",
      "        [ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,\n",
      "          1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,\n",
      "          1.0000,  1.0000,  1.0000,  1.0000],\n",
      "        [ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,\n",
      "          1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,\n",
      "          1.0000,  1.0000,  1.0000,  1.0000],\n",
      "        [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,\n",
      "         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,\n",
      "         -1.0000, -1.0000, -1.0000, -1.0000],\n",
      "        [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,\n",
      "         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,\n",
      "         -1.0000, -1.0000, -1.0000, -1.0000],\n",
      "        [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,\n",
      "         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,\n",
      "         -1.0000, -1.0000, -1.0000, -1.0000],\n",
      "        [ 0.5000, -0.5000,  0.5000, -0.5000,  0.5000, -0.5000,  0.5000, -0.5000,\n",
      "          0.5000, -0.5000,  0.5000, -0.5000,  0.5000, -0.5000,  0.5000, -0.5000,\n",
      "          0.5000, -0.5000,  0.5000, -0.5000]])\n"
     ]
    }
   ],
   "source": [
    "print(custom_embeddings.vectors)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We then build our vocabulary, passing our `Vectors` object.\n",
    "\n",
    "Note that the `unk_init` should be declared when creating our `Vectors`, and not here!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "MAX_VOCAB_SIZE = 25_000\n",
    "\n",
    "TEXT.build_vocab(train_data, \n",
    "                 max_size = MAX_VOCAB_SIZE, \n",
    "                 vectors = custom_embeddings)\n",
    "\n",
    "LABEL.build_vocab(train_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now our vocabulary vectors for the words in our custom embeddings should match what we loaded."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
       "        1., 1.])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "TEXT.vocab.vectors[TEXT.vocab.stoi['good']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n",
       "        -1., -1., -1., -1., -1., -1.])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "TEXT.vocab.vectors[TEXT.vocab.stoi['bad']]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Words that were in our custom embeddings but not in our dataset vocabulary are initialized by the `unk_init` function we passed earlier, $\\mathcal{N}(0,1)$. They are also the same size as our custom embeddings (20-dimensional)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-0.1117, -0.4966,  0.1631, -0.8817,  0.2891,  0.4899, -0.3853, -0.7120,\n",
       "         0.6369, -0.7141, -1.0831, -0.5547, -1.3248,  0.6970, -0.6631,  1.2158,\n",
       "        -2.5273,  1.4778, -0.1696, -0.9919])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "TEXT.vocab.vectors[TEXT.vocab.stoi['kwjibo']]"
   ]
  },
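  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The lookup of an unseen word can be pictured with a plain dictionary. This is only a sketch, under the assumption that `TEXT.vocab.stoi` behaves like a `defaultdict` that returns the `<unk>` index for words it has never seen; the toy vocabulary below is hypothetical.\n",
    "\n",
    "```python\n",
    "from collections import defaultdict\n",
    "\n",
    "UNK_IDX = 0\n",
    "\n",
    "# Toy vocabulary: unseen words fall back to the <unk> index.\n",
    "stoi = defaultdict(lambda: UNK_IDX)\n",
    "stoi.update({'<unk>': 0, '<pad>': 1, 'good': 2, 'bad': 3})\n",
    "\n",
    "print(stoi['good'])\n",
    "\n",
    "# Indexing a defaultdict with an unseen key also inserts that key,\n",
    "# so .get is used here for a read-only lookup.\n",
    "print(stoi.get('kwyjibo', UNK_IDX))\n",
    "```\n",
    "\n",
    "Either way, the unseen word ends up pointing at the row that belongs to `<unk>`."
   ]
  },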
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The rest of the set-up is the same as it is when using the GloVe vectors, with the next step being to set-up the iterators."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 64\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(\n",
    "    (train_data, valid_data, test_data), \n",
    "    batch_size = BATCH_SIZE,\n",
    "    device = device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, we define our model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class CNN(nn.Module):\n",
    "    def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, \n",
    "                 dropout, pad_idx):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx = pad_idx)\n",
    "        \n",
    "        self.convs = nn.ModuleList([\n",
    "                                    nn.Conv2d(in_channels = 1, \n",
    "                                              out_channels = n_filters, \n",
    "                                              kernel_size = (fs, embedding_dim)) \n",
    "                                    for fs in filter_sizes\n",
    "                                    ])\n",
    "        \n",
    "        self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)\n",
    "        \n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        \n",
    "    def forward(self, text):\n",
    "        \n",
    "        #text = [sent len, batch size]\n",
    "        \n",
    "        text = text.permute(1, 0)\n",
    "                \n",
    "        #text = [batch size, sent len]\n",
    "        \n",
    "        embedded = self.embedding(text)\n",
    "                \n",
    "        #embedded = [batch size, sent len, emb dim]\n",
    "        \n",
    "        embedded = embedded.unsqueeze(1)\n",
    "        \n",
    "        #embedded = [batch size, 1, sent len, emb dim]\n",
    "        \n",
    "        conved = [F.relu(conv(embedded)).squeeze(3) for conv in self.convs]\n",
    "            \n",
    "        #conv_n = [batch size, n_filters, sent len - filter_sizes[n]]\n",
    "        \n",
    "        pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]\n",
    "        \n",
    "        #pooled_n = [batch size, n_filters]\n",
    "        \n",
    "        cat = self.dropout(torch.cat(pooled, dim = 1))\n",
    "\n",
    "        #cat = [batch size, n_filters * len(filter_sizes)]\n",
    "            \n",
    "        return self.fc(cat)"
   ]
  },
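  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shape comments in `forward` follow from simple arithmetic. As a sketch, assuming a stride of 1 and no padding (the `nn.Conv2d` defaults, which the model above does not override), a filter of height `fs` sliding over a sentence of length `sent_len` produces `sent_len - fs + 1` positions. The `conv_output_len` helper and the sentence length of 50 are made up for illustration.\n",
    "\n",
    "```python\n",
    "def conv_output_len(sent_len, filter_size, stride=1, padding=0):\n",
    "    # Standard output-length formula for a window sliding along one axis.\n",
    "    return (sent_len + 2 * padding - filter_size) // stride + 1\n",
    "\n",
    "sent_len = 50\n",
    "for fs in [3, 4, 5]:\n",
    "    print(fs, conv_output_len(sent_len, fs))\n",
    "```\n",
    "\n",
    "A sentence shorter than the largest filter would make this length zero or negative, so we would expect the convolution to fail on such a batch."
   ]
  },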
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We then initialize our model, making sure `EMBEDDING_DIM` is the same as our custom embedding dimensionality, i.e. 20."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "INPUT_DIM = len(TEXT.vocab)\n",
    "EMBEDDING_DIM = 20\n",
    "N_FILTERS = 100\n",
    "FILTER_SIZES = [3,4,5]\n",
    "OUTPUT_DIM = 1\n",
    "DROPOUT = 0.5\n",
    "PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]\n",
    "\n",
    "model = CNN(INPUT_DIM, EMBEDDING_DIM, N_FILTERS, FILTER_SIZES, OUTPUT_DIM, DROPOUT, PAD_IDX)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We have a lot less parameters in this model due to the smaller embedding size used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The model has 524,641 trainable parameters\n"
     ]
    }
   ],
   "source": [
    "def count_parameters(model):\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "print(f'The model has {count_parameters(model):,} trainable parameters')"
   ]
  },
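  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The figure printed above can be reproduced by hand. The sketch below assumes a vocabulary of 25,002 tokens (the 25,000 most frequent words plus `<unk>` and `<pad>`) and the hyperparameters used in this notebook; the variable names are ours.\n",
    "\n",
    "```python\n",
    "vocab_size = 25_002   # 25,000 words + <unk> + <pad>\n",
    "embedding_dim = 20\n",
    "n_filters = 100\n",
    "filter_sizes = [3, 4, 5]\n",
    "output_dim = 1\n",
    "\n",
    "embedding = vocab_size * embedding_dim\n",
    "# Each Conv2d has n_filters kernels of shape (1, fs, embedding_dim) plus one bias per filter.\n",
    "convs = sum(n_filters * fs * embedding_dim + n_filters for fs in filter_sizes)\n",
    "# The linear layer maps the concatenated pooled features to output_dim, plus a bias.\n",
    "fc = len(filter_sizes) * n_filters * output_dim + output_dim\n",
    "\n",
    "total = embedding + convs + fc\n",
    "print(f'{embedding:,} + {convs:,} + {fc:,} = {total:,}')\n",
    "```\n",
    "\n",
    "The embedding layer accounts for roughly 95% of the parameters, so freezing it removes most of the trainable weights."
   ]
  },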
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we initialize our embedding layer to use our vocabulary vectors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.1117, -0.4966,  0.1631,  ...,  1.4778, -0.1696, -0.9919],\n",
       "        [-0.5675, -0.2772, -2.1834,  ...,  0.8504,  1.0534,  0.3692],\n",
       "        [-0.0552, -0.6125,  0.7500,  ..., -0.1261, -1.6770,  1.2068],\n",
       "        ...,\n",
       "        [ 0.5383, -0.1504,  1.6720,  ..., -0.3857, -1.0168,  0.1849],\n",
       "        [ 2.5640, -0.8564, -0.0219,  ..., -0.3389,  0.2203, -1.6119],\n",
       "        [ 0.1203,  1.5286,  0.6824,  ...,  0.3330, -0.6704,  0.5883]])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embeddings = TEXT.vocab.vectors\n",
    "\n",
    "model.embedding.weight.data.copy_(embeddings)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, we initialize the unknown and padding token embeddings to all zeros."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "UNK_IDX = TEXT.vocab.stoi[TEXT.unk_token]\n",
    "\n",
    "model.embedding.weight.data[UNK_IDX] = torch.zeros(EMBEDDING_DIM)\n",
    "model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Following standard procedure, we create our optimizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.optim as optim\n",
    "\n",
    "optimizer = optim.Adam(model.parameters())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define our loss function (criterion)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = nn.BCEWithLogitsLoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then place the loss function and the model on the GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model.to(device)\n",
    "criterion = criterion.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create the function to calculate accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def binary_accuracy(preds, y):\n",
    "    rounded_preds = torch.round(torch.sigmoid(preds))\n",
    "    correct = (rounded_preds == y).float()\n",
    "    acc = correct.sum() / len(correct)\n",
    "    return acc"
   ]
  },
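  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition, here is the same computation on a handful of made-up logits in plain Python, without tensors. The values are hypothetical; the point is that the model outputs raw logits, so they are passed through a sigmoid and rounded to 0 or 1 before being compared with the labels.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def sigmoid(x):\n",
    "    return 1.0 / (1.0 + math.exp(-x))\n",
    "\n",
    "logits = [2.0, -1.0, 0.5, -3.0]\n",
    "labels = [1.0, 0.0, 0.0, 0.0]\n",
    "\n",
    "# A probability above 0.5 (i.e. a positive logit) rounds to class 1.\n",
    "rounded = [float(round(sigmoid(z))) for z in logits]\n",
    "correct = [float(p == y) for p, y in zip(rounded, labels)]\n",
    "acc = sum(correct) / len(correct)\n",
    "print(rounded, acc)\n",
    "```\n",
    "\n",
    "Three of the four predictions match their labels, so the accuracy for this toy batch is 0.75."
   ]
  },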
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then implement our training function..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, iterator, optimizer, criterion):\n",
    "    \n",
    "    epoch_loss = 0\n",
    "    epoch_acc = 0\n",
    "    \n",
    "    model.train()\n",
    "    \n",
    "    for batch in iterator:\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "                \n",
    "        predictions = model(batch.text).squeeze(1)\n",
    "        \n",
    "        loss = criterion(predictions, batch.label)\n",
    "        \n",
    "        acc = binary_accuracy(predictions, batch.label)\n",
    "        \n",
    "        loss.backward()\n",
    "            \n",
    "        optimizer.step()\n",
    "        \n",
    "        epoch_loss += loss.item()\n",
    "        epoch_acc += acc.item()\n",
    "                \n",
    "    return epoch_loss / len(iterator), epoch_acc / len(iterator)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "...evaluation function..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(model, iterator, criterion):\n",
    "    \n",
    "    epoch_loss = 0\n",
    "    epoch_acc = 0\n",
    "    \n",
    "    model.eval()\n",
    "    \n",
    "    with torch.no_grad():\n",
    "    \n",
    "        for batch in iterator:\n",
    "            \n",
    "            predictions = model(batch.text).squeeze(1)\n",
    "            \n",
    "            loss = criterion(predictions, batch.label)\n",
    "            \n",
    "            acc = binary_accuracy(predictions, batch.label)\n",
    "\n",
    "            epoch_loss += loss.item()\n",
    "            epoch_acc += acc.item()\n",
    "        \n",
    "    return epoch_loss / len(iterator), epoch_acc / len(iterator)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "...and our helpful function that tells us how long an epoch takes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "def epoch_time(start_time, end_time):\n",
    "    elapsed_time = end_time - start_time\n",
    "    elapsed_mins = int(elapsed_time / 60)\n",
    "    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n",
    "    return elapsed_mins, elapsed_secs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We've finally reached training our model!\n",
    "\n",
    "## Freezing and Unfreezing Embeddings\n",
    "\n",
    "We're going to train our model for 10 epochs. During the first 5 epochs we are going to freeze the weights (parameters) of our embedding layer. For the last 10 epochs we'll allow our embeddings to be trained. \n",
    "\n",
    "Why would we ever want to do this? Sometimes the pre-trained word embeddings we use will already be good enough and won't need to be fine-tuned with our model. If we keep the embeddings frozen then we don't have to calculate the gradients and update the weights for these parameters, giving us faster training times. This doesn't really apply for the model used here, but we're mainly covering it to show how it's done. Another reason is that if our model has a large amount of parameters it may make training difficult, so by freezing our pre-trained embeddings we reduce the amount of parameters needing to be learned.\n",
    "\n",
    "To freeze the embedding weights, we set `model.embedding.weight.requires_grad` to `False`. This will cause no gradients to be calculated for the weights in the embedding layer, and thus no parameters will be updated when `optimizer.step()` is called.\n",
    "\n",
    "Then, during training we check if `FREEZE_FOR` (which we set to 5) epochs have passed. If they have then we set `model.embedding.weight.requires_grad` to `True`, telling PyTorch that we should calculate gradients in the embedding layer and update them with our optimizer."
   ]
  },
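  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The effect of `requires_grad` described above can be checked on a tiny stand-alone model before running the full training loop. This is a minimal sketch: the 10-word vocabulary, the 4-dimensional embeddings and the `step` helper are hypothetical and unrelated to the CNN used in this notebook.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "torch.manual_seed(1234)\n",
    "\n",
    "# Hypothetical toy model: a 10-word vocabulary with 4-dimensional embeddings.\n",
    "embedding = nn.Embedding(10, 4)\n",
    "fc = nn.Linear(4, 1)\n",
    "optimizer = optim.Adam(list(embedding.parameters()) + list(fc.parameters()))\n",
    "\n",
    "tokens = torch.tensor([1, 2, 3])\n",
    "\n",
    "def step():\n",
    "    optimizer.zero_grad()\n",
    "    loss = fc(embedding(tokens)).sum()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "# Frozen: no gradient is computed, so the optimizer leaves the weights alone.\n",
    "embedding.weight.requires_grad = False\n",
    "before = embedding.weight.clone()\n",
    "step()\n",
    "frozen_unchanged = torch.equal(before, embedding.weight)\n",
    "\n",
    "# Unfrozen: gradients flow again and the rows that were looked up are updated.\n",
    "embedding.weight.requires_grad = True\n",
    "step()\n",
    "unfrozen_changed = not torch.equal(before, embedding.weight)\n",
    "\n",
    "print(frozen_unchanged, unfrozen_changed)\n",
    "```\n",
    "\n",
    "Note that the embedding weights are handed to the optimizer from the start, so nothing needs to be re-registered when they are unfrozen."
   ]
  },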
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 01 | Epoch Time: 0m 7s | Frozen? True\n",
      "\tTrain Loss: 0.724 | Train Acc: 53.68%\n",
      "\t Val. Loss: 0.658 |  Val. Acc: 62.27%\n",
      "Epoch: 02 | Epoch Time: 0m 6s | Frozen? True\n",
      "\tTrain Loss: 0.670 | Train Acc: 59.36%\n",
      "\t Val. Loss: 0.626 |  Val. Acc: 67.51%\n",
      "Epoch: 03 | Epoch Time: 0m 6s | Frozen? True\n",
      "\tTrain Loss: 0.636 | Train Acc: 63.62%\n",
      "\t Val. Loss: 0.592 |  Val. Acc: 70.22%\n",
      "Epoch: 04 | Epoch Time: 0m 6s | Frozen? True\n",
      "\tTrain Loss: 0.613 | Train Acc: 66.22%\n",
      "\t Val. Loss: 0.573 |  Val. Acc: 71.77%\n",
      "Epoch: 05 | Epoch Time: 0m 6s | Frozen? True\n",
      "\tTrain Loss: 0.599 | Train Acc: 67.40%\n",
      "\t Val. Loss: 0.569 |  Val. Acc: 70.86%\n",
      "Epoch: 06 | Epoch Time: 0m 7s | Frozen? False\n",
      "\tTrain Loss: 0.577 | Train Acc: 69.53%\n",
      "\t Val. Loss: 0.520 |  Val. Acc: 76.17%\n",
      "Epoch: 07 | Epoch Time: 0m 7s | Frozen? False\n",
      "\tTrain Loss: 0.544 | Train Acc: 72.21%\n",
      "\t Val. Loss: 0.487 |  Val. Acc: 78.03%\n",
      "Epoch: 08 | Epoch Time: 0m 7s | Frozen? False\n",
      "\tTrain Loss: 0.507 | Train Acc: 74.96%\n",
      "\t Val. Loss: 0.450 |  Val. Acc: 80.02%\n",
      "Epoch: 09 | Epoch Time: 0m 7s | Frozen? False\n",
      "\tTrain Loss: 0.469 | Train Acc: 77.72%\n",
      "\t Val. Loss: 0.420 |  Val. Acc: 81.79%\n",
      "Epoch: 10 | Epoch Time: 0m 7s | Frozen? False\n",
      "\tTrain Loss: 0.426 | Train Acc: 80.28%\n",
      "\t Val. Loss: 0.392 |  Val. Acc: 82.76%\n"
     ]
    }
   ],
   "source": [
    "N_EPOCHS = 10\n",
    "FREEZE_FOR = 5\n",
    "\n",
    "best_valid_loss = float('inf')\n",
    "\n",
    "#freeze embeddings\n",
    "model.embedding.weight.requires_grad = unfrozen = False\n",
    "\n",
    "for epoch in range(N_EPOCHS):\n",
    "\n",
    "    start_time = time.time()\n",
    "    \n",
    "    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)\n",
    "    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)\n",
    "    \n",
    "    end_time = time.time()\n",
    "\n",
    "    epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n",
    "    \n",
    "    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s | Frozen? {not unfrozen}')\n",
    "    print(f'\\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')\n",
    "    print(f'\\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')\n",
    "    \n",
    "    if valid_loss < best_valid_loss:\n",
    "        best_valid_loss = valid_loss\n",
    "        torch.save(model.state_dict(), 'tutC-model.pt')\n",
    "    \n",
    "    if (epoch + 1) >= FREEZE_FOR:\n",
    "        #unfreeze embeddings\n",
    "        model.embedding.weight.requires_grad = unfrozen = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Another option would be to unfreeze the embeddings whenever the validation loss stops increasing using the following code snippet instead of the `FREEZE_FOR` condition:\n",
    "    \n",
    "```python\n",
    "if valid_loss < best_valid_loss:\n",
    "    best_valid_loss = valid_loss\n",
    "    torch.save(model.state_dict(), 'tutC-model.pt')\n",
    "else:\n",
    "    #unfreeze embeddings\n",
    "    model.embedding.weight.requires_grad = unfrozen = True\n",
    "```"
   ]
  },
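  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see how this rule behaves, the sketch below replays it over a made-up sequence of validation losses, with no model involved. The loss values and the `unfreeze_epoch` helper are hypothetical.\n",
    "\n",
    "```python\n",
    "def unfreeze_epoch(valid_losses):\n",
    "    # Return the 1-based epoch after which the embeddings would be unfrozen,\n",
    "    # or None if the validation loss improved on every epoch.\n",
    "    best_valid_loss = float('inf')\n",
    "    for epoch, valid_loss in enumerate(valid_losses, start=1):\n",
    "        if valid_loss < best_valid_loss:\n",
    "            best_valid_loss = valid_loss\n",
    "        else:\n",
    "            return epoch\n",
    "    return None\n",
    "\n",
    "print(unfreeze_epoch([0.66, 0.63, 0.59, 0.60, 0.57]))\n",
    "print(unfreeze_epoch([0.66, 0.63, 0.59]))\n",
    "```\n",
    "\n",
    "In the run recorded above, the validation loss was still improving at epoch 5, so this rule would have kept the embeddings frozen for longer than `FREEZE_FOR` did."
   ]
  },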
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Loss: 0.396 | Test Acc: 82.36%\n"
     ]
    }
   ],
   "source": [
    "model.load_state_dict(torch.load('tutC-model.pt'))\n",
    "\n",
    "test_loss, test_acc = evaluate(model, test_iterator, criterion)\n",
    "\n",
    "print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Saving Embeddings\n",
    "\n",
    "We might want to re-use the embeddings we have trained here with another model. To do this, we'll write a function that will loop through our vocabulary, getting the word and embedding for each word, writing them to a text file in the same format as our custom embeddings so they can be used with TorchText again.\n",
    "\n",
    "Currently, TorchText Vectors seem to have issues with loading certain unicode words, so we skip these by only writing words without unicode symbols. **If you know a better solution to this then let me know**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "def write_embeddings(path, embeddings, vocab):\n",
    "    \n",
    "    with open(path, 'w') as f:\n",
    "        for i, embedding in enumerate(tqdm(embeddings)):\n",
    "            word = vocab.itos[i]\n",
    "            #skip words with unicode symbols\n",
    "            if len(word) != len(word.encode()):\n",
    "                continue\n",
    "            vector = ' '.join([str(i) for i in embedding.tolist()])\n",
    "            f.write(f'{word} {vector}\\n')"
   ]
  },
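  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The length comparison in `write_embeddings` works because `str.encode()` defaults to UTF-8, where every ASCII character takes exactly one byte and every other character takes two or more. A short sketch with made-up words (the `has_non_ascii` helper is ours):\n",
    "\n",
    "```python\n",
    "def has_non_ascii(word):\n",
    "    # UTF-8 uses one byte per ASCII character and 2-4 bytes for anything else,\n",
    "    # so the byte length exceeds the character length exactly when a\n",
    "    # non-ASCII character is present.\n",
    "    return len(word) != len(word.encode())\n",
    "\n",
    "for word in ['film', 'café', 'naïve', '<unk>']:\n",
    "    print(word, has_non_ascii(word))\n",
    "```\n",
    "\n",
    "On Python 3.7 and later, `not word.isascii()` expresses the same test more directly."
   ]
  },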
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll write our embeddings to `trained_embeddings.txt`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 25002/25002 [00:00<00:00, 38085.03it/s]\n"
     ]
    }
   ],
   "source": [
    "write_embeddings('custom_embeddings/trained_embeddings.txt', \n",
    "                 model.embedding.weight.data, \n",
    "                 TEXT.vocab)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To double check they've written correctly, we can load them as `Vectors`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|███████   | 17550/24946 [00:00<00:00, 87559.48it/s]\n"
     ]
    }
   ],
   "source": [
    "trained_embeddings = vocab.Vectors(name = 'custom_embeddings/trained_embeddings.txt',\n",
    "                                   cache = 'custom_embeddings',\n",
    "                                   unk_init = torch.Tensor.normal_)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, let's print out the first 5 rows of our loaded vectors and the same from our model's embeddings weights, checking they are the same values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.2573, -0.2088,  0.2413, -0.1549,  0.1940, -0.1466, -0.2195, -0.1011,\n",
      "         -0.1327,  0.1803,  0.2369, -0.2182,  0.1543, -0.2150, -0.0699, -0.0430,\n",
      "         -0.1958, -0.0506, -0.0059, -0.0024],\n",
      "        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n",
      "          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n",
      "          0.0000,  0.0000,  0.0000,  0.0000],\n",
      "        [-0.1427, -0.4414,  0.7181, -0.5751, -0.3183,  0.0552, -1.6764, -0.3177,\n",
      "          0.6592,  1.6143, -0.1920, -0.1881, -0.4321, -0.8578,  0.5266,  0.5243,\n",
      "         -0.7083, -0.0048, -1.4680,  1.1425],\n",
      "        [-0.4700, -0.0363,  0.0560, -0.7394, -0.2412, -0.4197, -1.7096,  0.9444,\n",
      "          0.9633,  0.3703, -0.2243, -1.5279, -1.9086,  0.5718, -0.5721, -0.6015,\n",
      "          0.3579, -0.3834,  0.8079,  1.0553],\n",
      "        [-0.7055,  0.0954,  0.4646, -1.6595,  0.1138,  0.2208, -0.0220,  0.7397,\n",
      "         -0.1153,  0.3586,  0.3040, -0.6414, -0.1579, -0.2738, -0.6942,  0.0083,\n",
      "          1.4097,  1.5225,  0.6409,  0.0076]])\n"
     ]
    }
   ],
   "source": [
    "print(trained_embeddings.vectors[:5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.2573, -0.2088,  0.2413, -0.1549,  0.1940, -0.1466, -0.2195, -0.1011,\n",
      "         -0.1327,  0.1803,  0.2369, -0.2182,  0.1543, -0.2150, -0.0699, -0.0430,\n",
      "         -0.1958, -0.0506, -0.0059, -0.0024],\n",
      "        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n",
      "          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n",
      "          0.0000,  0.0000,  0.0000,  0.0000],\n",
      "        [-0.1427, -0.4414,  0.7181, -0.5751, -0.3183,  0.0552, -1.6764, -0.3177,\n",
      "          0.6592,  1.6143, -0.1920, -0.1881, -0.4321, -0.8578,  0.5266,  0.5243,\n",
      "         -0.7083, -0.0048, -1.4680,  1.1425],\n",
      "        [-0.4700, -0.0363,  0.0560, -0.7394, -0.2412, -0.4197, -1.7096,  0.9444,\n",
      "          0.9633,  0.3703, -0.2243, -1.5279, -1.9086,  0.5718, -0.5721, -0.6015,\n",
      "          0.3579, -0.3834,  0.8079,  1.0553],\n",
      "        [-0.7055,  0.0954,  0.4646, -1.6595,  0.1138,  0.2208, -0.0220,  0.7397,\n",
      "         -0.1153,  0.3586,  0.3040, -0.6414, -0.1579, -0.2738, -0.6942,  0.0083,\n",
      "          1.4097,  1.5225,  0.6409,  0.0076]], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "print(model.embedding.weight.data[:5])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "All looks good! The only difference between the two is the removal of the ~50 words in the vocabulary that contain unicode symbols."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}